from train_module import * 
from collections import Counter
from UPMSPEnv import * 
from UPMSPModel import *
import copy
import time
# Resolve the relative ./result and ./Test_data paths used below against this
# script's own directory rather than the caller's working directory.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

trainer_params = {
    'epochs': 1000,
    'train_pomo_size': 6,
    'num_param': 1,  # default
    'num_local': 10,
}

env_params = {
    'valid_batch_size': 20,
    'batch_size': 20,
    'pomo_size': 32,
    'job_num': 25,
    'machine_num': 3,
    # NOTE(review): presumably "fine_tuning" is the other supported mode
    # (paired with 'fine_tuning_path' below) — confirm in the trainer.
    'mode': 'rand',
    'fine_tuning_path': "model3_500.pt",
    'process_time_params': {
        's_max': 10,
        'T': 0.8,
        'R': 0.4,
        'm_p': 0.5,
    },
}

model_params = {
    'input_dim': 6,
    'embedding_dim': 64,
    'head_num': 8,
    'encoder_layer_num': 3,
    'latent_cont_dim': 2,
    'latent_disc_dim': 6,
}

optimizer_params = {
    'optimizer': {
        'lr': 5e-4,
        'weight_decay': 1e-6,
    }
}

UPMSP_Trainer1 = UPMSP_Trainer(
    env_params,
    model_params,
    optimizer_params,
    trainer_params,
)

# Restore the trained weights and switch the model to inference mode.
model_dict = './result/checkpoint-1000.pt'
# map_location='cpu' makes loading independent of the device the checkpoint
# was saved from; load_state_dict then copies each tensor into the model's
# existing parameters on whatever device they already live on.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
UPMSP_Trainer1.model.load_state_dict(torch.load(model_dict, map_location='cpu'))
# Per-(m, n, T, R) average scores, filled in by the evaluation loop.
UPMSP_Trainer1.list_1 = {}
UPMSP_Trainer1.model.eval()
# Independent copy of the trainer's env settings, so the per-setting
# job/machine overrides never touch the trainer's own dict.
env_params2 = copy.deepcopy(UPMSP_Trainer1.env_params)

def load_data4(m, n, t, r, num, pomo_size, s_max=10, root='./Test_data'):
    """Load one UPMSP test instance and replicate each field ``pomo_size`` times.

    The instance is read from
    ``{root}/machines_{m}/jobs_{n}/s_max_{s_max}/T_{t}/R_{r}/{num}.json``.

    Args:
        m: Machine count (used only as a path component).
        n: Job count (used only as a path component).
        t: ``T`` value of the instance set (path component).
        r: ``R`` value of the instance set (path component).
        num: Instance index, i.e. the JSON file stem.
        pomo_size: Number of copies of each field to return.
        s_max: ``s_max`` path component; defaults to 10 (the value that used
            to be hard-coded).
        root: Directory containing the test set; defaults to ``./Test_data``.

    Returns:
        A dict with keys ``p``, ``w``, ``setup_times``, ``due_dates``,
        ``release_times``, ``machine_eligibility`` and ``initial_setup``. Each
        value is a list of ``pomo_size`` references to the same loaded object
        (not copies).

    Raises:
        FileNotFoundError: If the instance file does not exist.
        KeyError: If the JSON lacks one of the expected fields
            (only when ``pomo_size > 0``).
    """
    import json  # explicit, rather than relying on the file's star imports

    keys = (
        'p',
        'w',
        'setup_times',
        'due_dates',
        'release_times',
        'machine_eligibility',
        'initial_setup',
    )

    file_path = f'{root}/machines_{m}/jobs_{n}/s_max_{s_max}/T_{t}/R_{r}/{num}.json'
    # json.load accepts a binary file and detects UTF-8/16/32 on its own.
    with open(file_path, 'rb') as f:
        data_tmp = json.load(f)

    return {key: [data_tmp[key] for _ in range(pomo_size)] for key in keys}


# One test instance per rollout, replicated `pomo_size` times so every copy is
# decoded with its own randomly sampled latent vector.
pomo_size = 6
batch_size = 1
# Latent sizes must match the dimensions the model was built with.
latent_disc_dim = model_params['latent_disc_dim']
latent_cont_dim = model_params['latent_cont_dim']
device = 'cuda'

list2 = []  # best reward per instance for the current (T, R) setting

total_time = 0
total = list()
index_counts = Counter()  # NOTE(review): not used in the visible code


def _sample_latent(n_batch, n_pomo, disc_dim, cont_dim):
    """Return a ``(n_batch * n_pomo, disc_dim + cont_dim)`` latent tensor.

    Each row is a random one-hot vector of length ``disc_dim`` followed by
    ``cont_dim`` values drawn uniformly from [-1, 1). The continuous part is
    drawn before the one-hot indices, so the RNG consumption order is fixed.
    """
    latent_c = torch.empty(n_batch, n_pomo, cont_dim).uniform_(-1, 1)

    latent_d = torch.zeros((n_batch, n_pomo, disc_dim), dtype=torch.float32)
    one_hot_idx = torch.randint(0, disc_dim, (n_batch, n_pomo), dtype=torch.long)
    latent_d[torch.arange(n_batch).unsqueeze(1), torch.arange(n_pomo).unsqueeze(0), one_hot_idx] = 1

    latent = torch.cat([latent_d, latent_c], dim=-1)
    return latent.reshape(n_batch * n_pomo, -1)


with torch.no_grad():
    for m in [3]:  # machine counts to evaluate
        for n in [50]:  # job counts to evaluate
            for T in [0.2, 0.4, 0.6, 0.8, 1]:
                for R in [0.2, 0.4, 0.6, 0.8, 1]:
                    print(f'Setting parameter: T_{T}, R_{R}')
                    for num in range(20):
                        # NOTE(review): env_params2['pomo_size'] stays at the
                        # trainer's value (32) while the data/latents use
                        # pomo_size (6); presumably the env sizes itself from
                        # `data` — confirm.
                        env_params2['job_num'] = n
                        env_params2['machine_num'] = m
                        # Load test instance `num` for this (m, n, T, R) setting.
                        data = load_data4(m, n, T, R, num, pomo_size)

                        start_t = time.perf_counter()
                        env_1 = Parallel_machine_tardiness(env_params2)
                        env_1._reset(data, -1)
                        latent_var = _sample_latent(
                            batch_size, pomo_size, latent_disc_dim, latent_cont_dim
                        ).to(device)
                        done = False

                        # Action history per copy; view(pomo_size, 1) below
                        # assumes batch_size == 1.
                        selected_list = torch.zeros(size=(pomo_size, 0), dtype=torch.long).to(device)
                        s = env_1._get_state()

                        while not done:
                            selected, _ = UPMSP_Trainer1.model.get_action(s, latent_var)
                            selected_list = torch.cat((selected_list, selected.view(pomo_size, 1)), dim=1)
                            s, r, done = env_1._step(selected)

                        # CUDA kernels run asynchronously; wait for them so the
                        # measured time covers the whole rollout.
                        torch.cuda.synchronize()
                        end_t = time.perf_counter()
                        # Keep the best reward across the pomo copies.
                        list2.append(torch.max(r).item())
                        model_time = end_t - start_t
                        total_time += model_time
                    # Score is the negated mean of the per-instance best rewards.
                    score = -sum(list2) / len(list2)
                    UPMSP_Trainer1.list_1[(m, n, T, R)] = score
                    print("Score : ", score)
                    total.append(score)
                    list2 = list()
    print(UPMSP_Trainer1.list_1)
    print("Average: ", sum(total) / len(total))
    print("Inference time: ", total_time)